In [ ]:
epochs = 10

Part 6 - Federated Learning on MNIST using a CNN

Upgrade to Federated Learning in 10 Lines of PyTorch + PySyft

Context

Federated Learning is an exciting and fast-growing Machine Learning technique for building systems that learn on decentralized data. The data remains in the hands of its producer (also known as the worker), which helps improve privacy and ownership, while the model is shared between workers. One immediate application is predicting the next word on your mobile phone as you type: you don't want the data used for training, i.e. your text messages, to be sent to a central server.

The rise of Federated Learning is therefore tightly connected to the spread of data privacy awareness, and the GDPR, which has enforced data protection in the EU since May 2018, has acted as a catalyst. To anticipate regulation, large actors like Apple and Google have started investing massively in this technology, especially to protect mobile users' privacy, but they have not made their tools available. At OpenMined, we believe that anyone willing to conduct a Machine Learning project should be able to implement privacy-preserving tools with very little effort. We have built tools for encrypting data in a single line, as mentioned in our blog post, and we now release our Federated Learning framework, which leverages the new PyTorch 1.0 version to provide an intuitive interface for building secure and scalable models.

In this tutorial, we'll directly use the canonical example of training a CNN on MNIST with PyTorch and show how simple it is to implement Federated Learning on top of it using our PySyft library. We will go through each part of the example and highlight the code that changes.

You can also find this material in our blogpost.

Authors:

Ok, let's get started!

Imports and model specifications

First, we make the same imports as in the official PyTorch example.


In [ ]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

And then those specific to PySyft. In particular, we define the remote workers alice and bob.


In [ ]:
import syft as sy  # <-- NEW: import the PySyft library
hook = sy.TorchHook(torch)  # <-- NEW: hook PyTorch, i.e. add extra functionalities to support Federated Learning
bob = sy.VirtualWorker(hook, id="bob")  # <-- NEW: define remote worker bob
alice = sy.VirtualWorker(hook, id="alice")  # <-- NEW: and alice

We define the settings of the learning task.


In [ ]:
class Arguments():
    def __init__(self):
        self.batch_size = 64
        self.test_batch_size = 1000
        self.epochs = epochs
        self.lr = 0.01
        self.momentum = 0.5  # kept from the official example; not passed to the optimizer below
        self.no_cuda = False
        self.seed = 1
        self.log_interval = 30
        self.save_model = False

args = Arguments()

use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

Data loading and sending to workers

We first load the data and use the .federate method to transform the training Dataset into a Federated Dataset split across the workers. This federated dataset is then passed to a Federated DataLoader. The test dataset remains unchanged.


In [ ]:
federated_train_loader = sy.FederatedDataLoader( # <-- this is now a FederatedDataLoader 
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
    .federate((bob, alice)), # <-- NEW: we distribute the dataset across all the workers, it's now a FederatedDataset
    batch_size=args.batch_size, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)
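
To build intuition for what a federated dataset and loader provide, here is a minimal pure-Python sketch. It is not PySyft's implementation: `ToyWorker`, `toy_federate` and `toy_federated_batches` are hypothetical names introduced only for illustration, and the contiguous, near-equal split is an assumption (PySyft may shard and order batches differently). The point it shows is that every batch comes from exactly one worker, so each batch has a well-defined location.

```python
from dataclasses import dataclass
from typing import Iterator, List, Sequence, Tuple

Sample = Tuple[float, int]  # (feature, label) stand-in for an (image, digit) pair


@dataclass
class ToyWorker:
    """Hypothetical stand-in for a remote worker; it only holds its own samples."""
    name: str
    samples: List[Sample]


def toy_federate(dataset: Sequence[Sample], names: Sequence[str]) -> List[ToyWorker]:
    """Split the dataset into contiguous, near-equal shards, one per worker (assumed policy)."""
    shard = -(-len(dataset) // len(names))  # ceiling division
    return [
        ToyWorker(name, list(dataset[i * shard:(i + 1) * shard]))
        for i, name in enumerate(names)
    ]


def toy_federated_batches(workers: Sequence[ToyWorker],
                          batch_size: int) -> Iterator[Tuple[str, List[Sample]]]:
    """Yield (location, batch) pairs; a batch never mixes samples from two workers."""
    for worker in workers:
        for start in range(0, len(worker.samples), batch_size):
            yield worker.name, worker.samples[start:start + batch_size]


dataset = [(float(i), i % 10) for i in range(10)]
workers = toy_federate(dataset, ["bob", "alice"])
batches = list(toy_federated_batches(workers, batch_size=2))
for location, batch in batches:
    print(location, batch)
```

In the real loader, the role of the `location` string is played by `data.location`, which the train function uses to decide where to send the model.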

CNN specification

Here we use exactly the same CNN as in the official example.


In [ ]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
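
If you wonder where the `4*4*50` in `fc1` comes from, the short calculation below traces the side length of a 28x28 MNIST image through the two convolutions and two poolings of the network above. It uses the standard output-size formula for convolutions and pooling without dilation; `conv_out` is a helper name introduced here for illustration only.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a square convolution or pooling window (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                                    # MNIST images are 28x28
side = conv_out(side, kernel=5)              # conv1 (kernel 5, stride 1): 28 -> 24
side = conv_out(side, kernel=2, stride=2)    # max_pool2d(2, 2): 24 -> 12
side = conv_out(side, kernel=5)              # conv2 (kernel 5, stride 1): 12 -> 8
side = conv_out(side, kernel=2, stride=2)    # max_pool2d(2, 2): 8 -> 4
flat_features = side * side * 50             # conv2 has 50 output channels
print(side, flat_features)                   # prints "4 800"
```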

Define the train and test functions

For the train function, because the data batches are distributed across alice and bob, you need to send the model to the right location for each batch. Then, you perform all the operations remotely with the same syntax as in local PyTorch. When you're done, you get the updated model back, along with the loss to monitor progress.


In [ ]:
def train(args, model, device, federated_train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(federated_train_loader): # <-- now it is a distributed dataset
        model.send(data.location) # <-- NEW: send the model to the right location
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        model.get() # <-- NEW: get the model back
        if batch_idx % args.log_interval == 0:
            loss = loss.get() # <-- NEW: get the loss back
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, len(federated_train_loader) * args.batch_size,
                100. * batch_idx / len(federated_train_loader), loss.item()))
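
The send / update / get pattern of the train function can be mimicked without any library. The sketch below is a toy under stated assumptions, not PySyft code: `ToyModel` and `sgd_step` are hypothetical, the "workers" are just labels, and the model is a single weight `w` fitted to `y = 3 * x`. It only illustrates that the model travels to wherever the current batch lives, is updated there, and comes back before the next batch.

```python
class ToyModel:
    """Hypothetical one-parameter model y = w * x that tracks where it currently lives."""

    def __init__(self, w=0.0):
        self.w = w
        self.location = "local"

    def send(self, location):
        self.location = location  # pretend the weights moved to the worker
        return self

    def get(self):
        self.location = "local"   # pretend the updated weights came back
        return self


def sgd_step(model, batch, lr):
    """One SGD step on the mean squared error of y = w * x; returns the batch loss."""
    n = len(batch)
    loss = sum((model.w * x - y) ** 2 for x, y in batch) / n
    grad = sum(2 * (model.w * x - y) * x for x, y in batch) / n
    model.w -= lr * grad
    return loss


# Each batch lives on exactly one worker; the underlying relation is y = 3 * x.
federated_batches = [
    ("bob", [(1.0, 3.0), (2.0, 6.0)]),
    ("alice", [(3.0, 9.0), (4.0, 12.0)]),
]

model = ToyModel()
visited = []
for epoch in range(50):
    for location, batch in federated_batches:
        model.send(location)              # go to the data
        visited.append(model.location)
        loss = sgd_step(model, batch, lr=0.02)
        model.get()                       # bring the updated model back

print(round(model.w, 3), model.location)  # w ends up close to 3, model is local again
```

Note that the updates here are sequential: bob's batch is processed, then alice's, with one model moving back and forth. Nothing is averaged or run in parallel.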

The test function does not change!


In [ ]:
def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(1, keepdim=True) # get the index of the max log-probability 
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

Launch the training!


In [ ]:
%%time
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr) # TODO momentum is not supported at the moment

for epoch in range(1, args.epochs + 1):
    train(args, model, device, federated_train_loader, optimizer, epoch)
    test(args, model, device, test_loader)

if args.save_model:
    torch.save(model.state_dict(), "mnist_cnn.pt")

Et voilà! Here you are, you have trained a model on remote data using Federated Learning!

One Last Thing

I know there's a question you're dying to ask: how long does it take to do Federated Learning compared to normal PyTorch?

The computation time is actually less than twice the time used for normal PyTorch execution! More precisely, it takes 1.9 times longer, which is very little compared to the features we were able to add.

Conclusion

As you can see, we modified 10 lines of code to upgrade the official PyTorch example on MNIST to a real Federated Learning setting!

Of course, there are dozens of improvements we could think of. We would like the computation to run in parallel on the workers and to perform federated averaging, to update the central model only every n batches, to reduce the number of messages exchanged between workers, etc. These are features we're working on to make Federated Learning ready for a production environment, and we'll write about them as soon as they are released!
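
To make the idea of federated averaging concrete, here is a minimal sketch in plain Python. It is not a PySyft API: `federated_average` is a hypothetical helper, parameters are represented as flat lists of floats rather than tensors, and weighting each worker by its number of local samples is the usual convention, assumed here. Each worker would train its own copy of the model locally, and the central model would then be replaced by this weighted average.

```python
from typing import Dict, List, Sequence

Params = Dict[str, List[float]]  # parameter name -> flattened values


def federated_average(worker_params: Sequence[Params],
                      num_samples: Sequence[int]) -> Params:
    """Average each parameter across workers, weighted by local sample count."""
    total = sum(num_samples)
    averaged: Params = {}
    for name in worker_params[0]:
        length = len(worker_params[0][name])
        averaged[name] = [
            sum(params[name][i] * n for params, n in zip(worker_params, num_samples)) / total
            for i in range(length)
        ]
    return averaged


# Toy parameters after local training on each worker (made-up numbers).
bob_params = {"weight": [1.0, 2.0], "bias": [0.0]}
alice_params = {"weight": [3.0, 6.0], "bias": [4.0]}

# alice holds three times as much data as bob, so her update counts three times as much.
global_params = federated_average([bob_params, alice_params], num_samples=[1, 3])
print(global_params)  # prints {'weight': [2.5, 5.0], 'bias': [3.0]}
```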

You should now be able to do Federated Learning by yourself! If you enjoyed this and would like to join the movement toward privacy-preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!

Star PySyft on GitHub

The easiest way to help our community is just by starring the repositories! This helps raise awareness of the cool tools we're building.

Check out our tutorials on GitHub!

We made really nice tutorials to get a better understanding of what Federated and Privacy-Preserving Learning should look like and how we are building the bricks for this to happen.

Join our Slack!

The best way to keep up to date on the latest advancements is to join our community!

Join a Code Project!

The best way to contribute to our community is to become a code contributor! If you want to start "one off" mini-projects, you can go to the PySyft GitHub Issues page and search for issues marked Good First Issue.

If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!

